
#!/usr/bin/env python3
# Q3 “Real Run”: Horizon‑Proximity (NoCommit boundary)
# Boolean/ordinal acceptance; PF/Born only at ties; no metric weights.

from __future__ import annotations
import os, csv, json, math
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Tuple, Dict, Iterable, List
import numpy as np


@dataclass(frozen=True)
class Grid:
    """Scene geometry, in lattice cells, as recorded in the run manifest.

    In this file these values are only written to the manifest; the
    simulation routines do not read them. Field order is part of the
    positional-constructor interface and must stay stable.
    """

    Lx: int = 509        # grid extent along x
    Ly: int = 481        # grid extent along y
    x_screen: int = 508  # screen column
    x_H: int = 260       # horizon centre, x
    y_H: int = 0         # horizon centre, y
    r_H: int = 25        # horizon radius
    W: int = 24          # proximity band width (cells)


@dataclass(frozen=True)
class Instrument:
    """Instrument settings for the Δt visibility scans.

    Field order is part of the positional-constructor interface and must
    stay stable.
    """

    Theta: float = 1.0           # scale for the jitter: J = J_frac * Theta
    M: int = 24                  # recorded in manifest/summary; not read by the simulation here
    J_frac: float = 0.15         # jitter half-width as a fraction of Theta
    epsilon: float = 0.1         # half-width of each window at phase 0 and 0.5 (disjoint while < 0.25)
    seeds: Tuple[int, ...] = (101, 202, 303)   # one full Δt scan per seed
    K_scan: int = 20000          # trials per delta_t point
    K_summary: int = 50000       # recorded in the manifest only in this file
    delta_t_points: int = 201    # number of delta_t samples on [-0.5, 0.5]
    Delta_h_list: Tuple[int, ...] = (15,)      # upper-lower separation rows


@dataclass(frozen=True)
class Sweep:
    """Distances swept in the experiment.

    ``d_list`` holds lower-arm offsets from the horizon edge (the manifest
    places the lower arm at ``y_H + (r_H + d)``); values <= 0 are inside the
    NoCommit region. ``run_bias_slice`` is recorded in the manifest but not
    acted on elsewhere in this file.
    """

    d_list: Tuple[float, ...] = (20, 12, 8, 4, 2, 1, 0, -2)
    run_bias_slice: bool = False


@dataclass(frozen=True)
class OutputPaths:
    """Output directory and artifact file names (joined onto ``out_dir``).

    ``bias_csv`` is declared but not written anywhere in this file.
    """

    out_dir: str = "out/q3"
    scans_csv: str = "q3_horizon_scans.csv"
    summary_csv: str = "q3_horizon_visibility_summary.csv"
    manifest_yaml: str = "q3_horizon_manifest.yaml"
    audit_json: str = "q3_horizon_audit.json"
    bias_csv: str = "q3_bias_vs_d.csv"


def tau_of_d(d: float) -> float:
    """Return the ordinal acceptance threshold τ(d) for the lower arm.

    A trial is admissible when its uniform ordinal ``u`` satisfies
    ``u >= τ(d)`` (see ``phi_lower``), so a larger τ admits fewer trials.

    Schedule: τ = 1 for d <= 0 (NoCommit), τ = 0 for d >= 20. The setpoints
    d = 1, 2, 4, 8, 12 map to 0.98, 0.90, 0.70, 0.40, 0.20. Between
    setpoints τ is linearly interpolated, including on 0 < d < 1, which
    runs from 1.0 at d = 0 to 0.98 at d = 1.

    Args:
        d: distance of the lower arm from the horizon edge (cells).

    Returns:
        τ in [0, 1], non-increasing in d.

    Raises:
        ValueError: if ``d`` is NaN.
    """
    if math.isnan(d):
        raise ValueError("tau_of_d: d must be a number, got NaN")
    if d <= 0:
        return 1.0
    if d >= 20:
        return 0.0
    # (d, τ) setpoints in ascending d. The (0, 1.0) anchor gives 0 < d < 1 a
    # segment to interpolate on; without it those d fell through to τ = 0,
    # i.e. a fully open arm right at the horizon.
    steps = [(0, 1.0), (1, 0.98), (2, 0.90), (4, 0.70), (8, 0.40), (12, 0.20), (20, 0.0)]
    # Snap to interior setpoints so near-exact inputs return the tabulated value.
    for dk, tk in steps[1:-1]:
        if math.isclose(d, dk):
            return tk
    for (dl, tl), (dr, tr) in zip(steps, steps[1:]):
        if dl <= d <= dr:
            w = (d - dl) / (dr - dl)
            return tl + w * (tr - tl)
    # Unreachable: every d in (0, 20) lies in one of the segments above.
    return 0.0


def wrap_half(x):
    """Fold a phase ``x`` (in cycles) into [-0.5, 0.5), up to float rounding.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    shifted = x + 0.5
    folded = shifted % 1.0
    return folded - 0.5


def constructive_windows(delta_t: float, jitter: np.ndarray, epsilon: float):
    """Flag trials whose phase falls in the constructive windows.

    The per-trial phase is ``delta_t + jitter`` (ΔT = 0 by equal arm
    lengths). Returns ``(c0, c1)``: boolean arrays that are True where the
    wrapped phase lies within ``epsilon`` of 0 and of 0.5, respectively.
    """
    phase = delta_t + jitter
    near_zero = np.abs(wrap_half(phase)) <= epsilon
    near_half = np.abs(wrap_half(phase - 0.5)) <= epsilon
    return near_zero, near_half


def phi_lower(d: float, u: np.ndarray) -> np.ndarray:
    """Classify each trial's lower arm: 0 = admissible, 1 = blocked.

    A trial is admissible only when ``d > 0`` and its ordinal ``u`` reaches
    the threshold ``tau_of_d(d)``.
    """
    threshold = tau_of_d(d)
    admissible = np.logical_and(d > 0.0, u >= threshold)
    return np.where(admissible, 0, 1)


def simulate_delta_t(d: float, delta_t: float, K: int, J: float, eps: float, rng: np.random.Generator) -> Dict[str, float]:
    """Simulate K trials at a single phase offset ``delta_t``.

    Each trial draws an ordinal u ~ U[0, 1) and a jitter ~ U[-J, J) from
    ``rng``. The lower arm is admissible per ``phi_lower``. Among admissible
    trials:
      * a trial inside exactly one window counts for that outcome;
      * a trial inside both windows is settled by a fair coin.
    Neutral trials are also settled by a fair coin. These are admissible
    trials outside both windows, plus all trials whose lower arm is blocked.
    Hence the two counts sum to K.

    Returns a dict of fractions of K: ``p0``, ``p1``, ``neutral``,
    ``tie_any`` (in either window), ``tie_both`` (in both windows) and
    ``two_arm_enable`` (admissible lower arm).
    """
    ordinal = rng.random(K)
    jitter = rng.uniform(-J, J, K)
    lower_open = phi_lower(d, ordinal) == 0

    in0, in1 = constructive_windows(delta_t, jitter, eps)
    in0 = in0 & lower_open
    in1 = in1 & lower_open

    tied = in0 & in1
    n_tied = int(tied.sum())
    count0 = int((in0 & ~in1).sum())
    count1 = int((in1 & ~in0).sum())
    if n_tied:
        # toss values are 0/1, so their sum counts the outcome-1 draws.
        heads = int(rng.integers(0, 2, size=n_tied).sum())
        count0 += n_tied - heads
        count1 += heads

    # NOTE(review): neutral trials are coin-tossed too, although the file
    # header says "PF/Born only at ties" — confirm this is intended.
    n_neutral = int((lower_open & ~(in0 | in1)).sum()) + int((~lower_open).sum())
    if n_neutral:
        heads = int(rng.integers(0, 2, size=n_neutral).sum())
        count0 += n_neutral - heads
        count1 += heads

    return dict(
        p0=count0 / K,
        p1=count1 / K,
        neutral=n_neutral / K,
        tie_any=((in0 | in1).sum()) / K,
        tie_both=(tied.sum()) / K,
        two_arm_enable=lower_open.mean(),
    )


def scan_over_delta_t(d: float, seed: int, inst: Instrument):
    """Scan the phase offset delta_t over [-0.5, 0.5] at distance ``d``.

    A fresh generator is built from ``seed``, so the result depends only on
    ``(d, seed, inst)``.

    Returns ``(curves, summary)``:
      * ``curves`` — array with columns delta_t, p0, p1, neutral, tie_any,
        one row per delta_t point;
      * ``summary`` — Imax/Imin of the p0 curve, its visibility
        V = (Imax - Imin) / (Imax + Imin) (0 when the sum is 0),
        the mean lower-arm pass rate, the mean neutral fraction over
        |delta_t| <= 0.25, and the mean tie_any fraction.
    """
    generator = np.random.default_rng(seed)
    offsets = np.linspace(-0.5, 0.5, inst.delta_t_points)
    jitter_half_width = inst.J_frac * inst.Theta

    # Points run sequentially on one generator; order matters for reproducibility.
    per_point = [
        simulate_delta_t(d, float(offset), inst.K_scan, jitter_half_width, inst.epsilon, generator)
        for offset in offsets
    ]

    def column(key):
        return np.array([res[key] for res in per_point], dtype=float)

    p0 = column("p0")
    p1 = column("p1")
    neutral = column("neutral")
    tie_any = column("tie_any")
    enable = column("two_arm_enable")

    i_max = float(p0.max())
    i_min = float(p0.min())
    total = i_max + i_min
    visibility = (i_max - i_min) / total if total != 0 else 0.0

    centre = np.abs(offsets) <= 0.25
    summary = dict(
        Imax=i_max,
        Imin=i_min,
        V=visibility,
        pass_rate_lower=float(enable.mean()),
        neutral_center=float(neutral[centre].mean()),
        tie_fraction=float(tie_any.mean()),
    )
    return np.column_stack([offsets, p0, p1, neutral, tie_any]), summary


def ensure_dir(path: str) -> None:
    """Create directory ``path`` and any missing parents; no-op if it already exists."""
    os.makedirs(path, mode=0o777, exist_ok=True)


def write_scans_csv(path: str, records: Iterable[Dict[str, object]]) -> None:
    """Write per-point scan records to ``path`` (overwriting it).

    Only the fixed column set below is written; any extra keys in a record
    are ignored, and a missing key raises ``KeyError``.
    """
    columns = ["d", "delta_t", "p0", "p1", "neutral", "tie_fraction", "seed"]
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({name: rec[name] for name in columns} for rec in records)


def write_summary_csv(path: str, rows: Iterable[Dict[str, object]]) -> None:
    """Write one summary row per (Dh, d) to ``path`` (overwriting it).

    Only the fixed column set below is written; extra keys are ignored and
    a missing key raises ``KeyError``.
    """
    columns = [
        "d", "pass_rate_lower_arm", "visibility_V", "neutral_center", "tie_fraction",
        "Imax", "Imin", "K_per_point", "J", "epsilon", "M", "Dh", "seeds_used",
    ]
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({name: row[name] for name in columns} for row in rows)


def dump_manifest_yaml(path: str, grid: Grid, inst: Instrument, sweep: Sweep) -> None:
    """Write a hand-formatted YAML manifest of the run configuration to ``path``.

    ``J`` is the effective jitter half-width ``J_frac * Theta``, the same
    value the scan and the summary CSV use. ``J_frac`` is recorded
    alongside it.

    The file is written as UTF-8 because it contains non-ASCII text ("Δh").
    """
    # Effective jitter half-width; previously J_frac was written under "J",
    # which only matched when Theta == 1.
    J = inst.J_frac * inst.Theta
    d_values = ", ".join(f"{x:.6g}" for x in sweep.d_list)
    lines = [
        "scene:",
        f"  grid: {{Lx: {grid.Lx}, Ly: {grid.Ly}, x_screen: {grid.x_screen}}}",
        f"  horizon: {{center: [{grid.x_H}, {grid.y_H}], radius: {grid.r_H}, W: {grid.W}}}",
        "  arms:",
        "    y_centering: 'y = i - floor(Ly/2)'",
        "    lower: 'y_H + (r_H + d)'",
        "    upper: 'lower + Δh'",
        "instrument:",
        f"  Theta: {inst.Theta}",
        f"  M: {inst.M}",
        f"  J: {J}",
        f"  J_frac: {inst.J_frac}",
        f"  epsilon: {inst.epsilon}",
        f"  seeds: [{', '.join(map(str, inst.seeds))}]",
        f"  K_scan: {inst.K_scan}",
        f"  K_summary: {inst.K_summary}",
        "sweep:",
        f"  d_list: [{d_values}]",
        f"  Delta_h_list: [{', '.join(map(str, inst.Delta_h_list))}]",
        f"  delta_t_points: {inst.delta_t_points}",
        f"  run_bias_slice: {str(sweep.run_bias_slice).lower()}",
        "schedule_tau:",
        "  '>=20': 0.00",
        "  12: 0.20",
        "  8: 0.40",
        "  4: 0.70",
        "  2: 0.90",
        "  1: 0.98",
        "  '<=0': NoCommit",
    ]
    # Explicit UTF-8: the platform default (e.g. cp1252) cannot encode "Δ".
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def dump_audit_json(path: str) -> None:
    """Write a small JSON audit record (UTC timestamp plus guardrail flags) to ``path``."""
    record = dict(
        timestamp_utc=datetime.now(timezone.utc).isoformat(),
        guardrails=dict(curve_lint=True, no_skip=True, pf_born_ties_only=True),
    )
    with open(path, "w") as sink:
        json.dump(record, sink, indent=2)


def _michelson_visibility(i_max: float, i_min: float) -> float:
    """Return (i_max - i_min) / (i_max + i_min), or 0.0 when the sum is zero."""
    total = i_max + i_min
    return 0.0 if total == 0 else (i_max - i_min) / total


def run_q3_once(grid: Grid, inst: Instrument, sweep: Sweep, out: OutputPaths) -> None:
    """Run the full horizon-proximity sweep and write every artifact to ``out.out_dir``.

    Writes the manifest and audit files first. Then, for every
    (Dh, d) pair, it runs one Δt scan per seed and collects two outputs:
      * every scan point, into the scans CSV;
      * one summary row per pair, into the summary CSV. Its visibility is
        computed from the seed-averaged p0 curve; pass rate, neutral and
        tie fractions are averaged over seeds.

    Raises:
        ValueError: if ``inst.seeds`` is empty (no mean curve can be formed).
            Raised before any file is written.
    """
    if not inst.seeds:
        raise ValueError("Instrument.seeds must contain at least one seed")

    ensure_dir(out.out_dir)
    scans_path = os.path.join(out.out_dir, out.scans_csv)
    summary_path = os.path.join(out.out_dir, out.summary_csv)
    dump_manifest_yaml(os.path.join(out.out_dir, out.manifest_yaml), grid, inst, sweep)
    dump_audit_json(os.path.join(out.out_dir, out.audit_json))

    nseeds = len(inst.seeds)
    J = inst.J_frac * inst.Theta
    seeds_used = ",".join(map(str, inst.seeds))
    # scan_over_delta_t builds its own RNG from the seed, so its output depends
    # only on (d, seed, inst). Reuse it across Dh values instead of recomputing.
    scan_cache: Dict[Tuple[float, int], Tuple[np.ndarray, Dict[str, float]]] = {}

    scan_records: List[Dict[str, object]] = []
    summary_rows: List[Dict[str, object]] = []

    for Dh in inst.Delta_h_list:
        # NOTE(review): Dh does not enter the simulation and scan rows carry no
        # Dh column, so several Dh values yield duplicate scan rows — confirm.
        for d in sweep.d_list:
            p0_sum = None
            per_seed_pass, per_seed_neutral, per_seed_tie = [], [], []
            for seed in inst.seeds:
                key = (d, seed)
                if key not in scan_cache:
                    scan_cache[key] = scan_over_delta_t(d, seed, inst)
                curves, s = scan_cache[key]
                scan_records.extend(
                    {
                        "d": d,
                        "delta_t": float(row[0]),
                        "p0": float(row[1]),
                        "p1": float(row[2]),
                        "neutral": float(row[3]),
                        "tie_fraction": float(row[4]),
                        "seed": seed,
                    }
                    for row in curves
                )
                per_seed_pass.append(s["pass_rate_lower"])
                per_seed_neutral.append(s["neutral_center"])
                per_seed_tie.append(s["tie_fraction"])
                # Accumulate p0 only: it is the single column the summary needs.
                # Never mutate cached arrays in place.
                p0_sum = curves[:, 1].copy() if p0_sum is None else p0_sum + curves[:, 1]

            p0_mean = p0_sum / nseeds
            Imax_mean = float(p0_mean.max())
            Imin_mean = float(p0_mean.min())

            summary_rows.append({
                "d": d,
                "pass_rate_lower_arm": float(np.mean(per_seed_pass)),
                "visibility_V": _michelson_visibility(Imax_mean, Imin_mean),
                "neutral_center": float(np.mean(per_seed_neutral)),
                "tie_fraction": float(np.mean(per_seed_tie)),
                "Imax": Imax_mean,
                "Imin": Imin_mean,
                "K_per_point": inst.K_scan,
                "J": J,
                "epsilon": inst.epsilon,
                "M": inst.M,
                "Dh": Dh,
                "seeds_used": seeds_used,
            })

    write_scans_csv(scans_path, scan_records)
    write_summary_csv(summary_path, summary_rows)


def check_pass_fail(summary_rows: List[Dict[str, object]]) -> Dict[str, object]:
    """Evaluate the Q3 acceptance criteria on summary rows.

    Rows may come straight from ``csv.DictReader`` (string values);
    ``d`` and ``visibility_V`` are converted with ``float``. If several rows
    share a ``d``, the last one wins.

    Criteria:
      * ``monotone_visibility``: every d in 20, 12, 8, 4, 2, 1 must be
        present, and visibility may not rise by more than 1e-3 from one
        step to the next.
      * ``V_at_0_is_zero``: a row for d = 0 exists and its |V| <= 5e-3.

    ``PASS`` is the conjunction of both criteria.
    """
    vis = {float(row["d"]): float(row["visibility_V"]) for row in summary_rows}

    ladder = (20, 12, 8, 4, 2, 1)  # d = 0 is judged by the zero test below
    rise_tol = 1e-3

    def step_ok(far, near):
        v_far = vis.get(far)
        v_near = vis.get(near)
        return v_far is not None and v_near is not None and not (v_near - v_far > rise_tol)

    monotone = all(step_ok(far, near) for far, near in zip(ladder, ladder[1:]))

    v_zero = vis.get(0.0)
    zero_ok = v_zero is not None and abs(v_zero) <= 5e-3

    return {"monotone_visibility": monotone, "V_at_0_is_zero": zero_ok, "PASS": bool(monotone and zero_ok)}


def main(*, report: bool = False) -> None:
    """Run the Q3 horizon-proximity sweep with the default configuration.

    Args:
        report: when True, read back the summary CSV just written and print
            the ``check_pass_fail`` verdict. Defaults to False, which produces
            no console output.
    """
    grid = Grid()
    inst = Instrument()
    sweep = Sweep()
    out = OutputPaths()
    run_q3_once(grid, inst, sweep, out)
    if report:
        # Re-read from disk so the verdict reflects exactly what was written.
        summary_path = os.path.join(out.out_dir, out.summary_csv)
        with open(summary_path, newline="") as f:
            rows = list(csv.DictReader(f))
        print(check_pass_fail(rows))


if __name__ == "__main__":
    # Run the default sweep only when executed as a script, not on import.
    main()
